import copy
import json
import os
import pickle
from math import floor

import numpy as np
import torch


# ara= torch.randn(10)*4
# print(ara)
# ret= map_to_exact_discrete(ara)
# print(ret)
from matplotlib import pyplot as plt
from numpy import uint8

from CausalMNISTAddition.DigitImageGeneration.mnist_image_generation import plot_trained_digits
from Causal_Train_By_Components.Causal_TrainGraph import set_trainGraph
from Causal_Train_By_Components.GAN_Evaluation.EvaluateCausalGAN import evaluate_after_epochs, \
    get_expected_loss_interventions, get_expected_loss_countefactuals, get_observational_loss, \
    get_expected_true_cf
from Causal_Partial_Mnist.CausalGraph_Mnist import set_mnist_nonId_newgraph, getdoKey, set_mnist_random_graph
from CausalMNISTAddition.GroundTruth.Synthetic_Distribution_Mnist import get_synthetic_dist, get_bayesian_network, \
    save_queryscm, get_cond_synthetic_dist, get_intv_dist
from CausalMNISTAddition.GroundTruth.True_Counterfactuals_Mnist import get_cf_dist, get_cf_from_intervs
from Causal_Partial_Mnist.RejectionSampling_Optimized import rejection_sampling_optimized
from Causal_Partial_Mnist.mnistControllerModel import get_generators, get_generated_labels
from CausalMNISTAddition.mnistInterventionalDistributions import compare_interventions, compare_conditionals_within
import pyAgrum as gum

from CausalTwoDiscrimMechTrain.ConstantFunctions import asKey
from CausalTwoDiscrimMechTrain.ControllerConstants import map_dictfill_to_discrete, map_fill_to_discrete, \
    generate_permutations
from CausalTwoDiscrimMechTrain.Counterfactuals.CounterfactualDistribution import get_cf_samples
from CausalTwoDiscrimMechTrain.Counterfactuals.RejectionSampling import rejection_sampling
from CausalTwoDiscrimMechTrain.DistributionComparison import match_with_true_dist, get_distributions_from_samples, \
    get_joint_distributions_from_samples
from CausalTwoDiscrimMechTrain.Experiment_Class import Experiment
import torchvision.transforms as transforms

from CausalTwoDiscrimMechTrain.ProbabilityCalculation import calculate_TVD, calculate_KL



def check_query(Exp):
    """Print two total-variation diagnostics for one fixed counterfactual query.

    The query is ``Ycolor`` under ``do(X1=1, X2=9)`` given the evidence
    ``X1p=0, X2p=0``.  First the distribution returned by ``get_cf_dist`` is
    compared with the one returned by ``get_intv_dist``; afterwards the same
    counterfactual-vs-interventional comparison is repeated on distributions
    estimated from generator samples.

    NOTE: the generators are taken from the module-level name
    ``label_generators`` (it is not a parameter), so that name has to be bound
    before this function is called.

    Args:
        Exp: experiment object.  ``cf_queries`` and ``Synthetic_Sample_Size``
            are read from it, and it is forwarded to every helper call.

    Returns:
        None.  Results are only printed.
    """
    feature_name = "feature"
    observed = ["Ycolor"]
    intervention = {"X1": 1, "X2": 9}
    evidence_values = {"X1p": 0, "X2p": 0}

    # --- reference distributions: counterfactual vs. plain intervention ---
    reference_cf = get_cf_dist(Exp, observed, intervention, evidence_values, "testing_cf", load_dist=True)
    do_key = getdoKey(observed, dict(intervention))  # name under which the SCM query is saved
    reference_intv = get_intv_dist(Exp, observed, dict(intervention), do_key)
    reference_tvd = calculate_TVD(reference_cf, reference_intv, doPrint=False)
    print(reference_tvd)

    # --- posterior samples for every evidence setting of the first CF query ---
    query = Exp.cf_queries[0]
    every_evidence = list(query["evidence"])
    label_by_ev, latent_by_ev, noise_by_ev = rejection_sampling_optimized(
        Exp,
        label_generators,
        Exp.Synthetic_Sample_Size,
        every_evidence,
        max_rejections=0,
        warn=100,
    )

    ev_key = asKey(evidence_values)
    post_label = label_by_ev[ev_key]
    post_latent = latent_by_ev[ev_key]
    post_noise = noise_by_ev[ev_key]

    # --- generated counterfactual samples ---
    cf_label_dict = get_generated_labels(
        Exp, label_generators, post_label, post_latent,
        intervention, observed, Exp.Synthetic_Sample_Size, gumbel_noise=post_noise,
    )
    cf_discrete = map_dictfill_to_discrete(Exp, cf_label_dict, observed)

    query_cf = get_cf_dist(Exp, observed, intervention, evidence_values, query["expr"], load_dist=True)
    # The returned (tvd, kl) pair is not used further; the call is kept as in the
    # original flow.
    _cf_tvd, _cf_kl = match_with_true_dist(Exp, observed, cf_discrete, query_cf, feature_name,
                                           doPrint=False)

    generated_cf_dist = get_joint_distributions_from_samples(Exp, observed, cf_discrete, feature_name)

    # --- generated interventional samples (no posterior, no evidence) ---
    intv_label_dict = get_generated_labels(Exp, label_generators, {}, {}, intervention, observed,
                                           Exp.Synthetic_Sample_Size)
    intv_discrete = map_dictfill_to_discrete(Exp, intv_label_dict, observed)
    generated_intv_dist = get_joint_distributions_from_samples(Exp, observed, intv_discrete, feature_name)

    generated_tvd = calculate_TVD(generated_cf_dist, generated_intv_dist, doPrint=False)
    print("cf vs intv", generated_tvd)




def get_conditional_sample_for_images(Exp, label_generators):
    """Save stats for ten random generated samples with (X1, X2) == (0, 1).

    Observational labels are generated, the rows whose first two entries equal
    ``obs_conf`` are kept, and the empirical probability of every distinct
    label combination inside that group is compared with the reference
    conditional distribution from ``get_cond_synthetic_dist``.  Ten rows are
    then drawn at random (with replacement) from the group, ordered by
    descending empirical probability, printed, and written as JSON to
    ``"prob_save_file_path"``.

    Args:
        Exp: experiment object; ``Synthetic_Sample_Size`` and ``label_names``
            are read from it and it is forwarded to the helpers.
        label_generators: generators forwarded to ``get_generated_labels``.

    Returns:
        None.

    Raises:
        KeyError: if no generated row has ``(X1, X2) == obs_conf``, or if a
            generated combination cannot be looked up in the reference
            distribution.
    """
    obs_vars = ["X1", "X2", "W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"]
    # (X1, X2) configuration to inspect; (0, 8) and (1, 4) were used previously.
    obs_conf = (0, 1)
    num_samples = 10

    intv_key = {}  # empty intervention -> observational setting
    true_bn, _ = get_bayesian_network(Exp, intv_key, load_scm=1)
    _, _, _, true_dist_dict = get_cond_synthetic_dist(
        ["W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"], ["X1", "X2"],
        Exp.label_names, true_bn["feature"])

    generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_key, obs_vars,
                                                 Exp.Synthetic_Sample_Size)
    generated_labels_full = map_dictfill_to_discrete(Exp, generated_labels_dict, obs_vars)

    # Only the obs_conf group is ever used, so keep just those rows.
    # assumes each row is array-like with X1, X2 in positions 0 and 1 — TODO confirm
    all_samples = [torch.tensor(row).view(1, -1)
                   for row in generated_labels_full
                   if (row[0], row[1]) == obs_conf]
    if not all_samples:
        raise KeyError(obs_conf)

    randices = torch.randint(0, len(all_samples), (num_samples,)).tolist()
    samples = [all_samples[idx] for idx in randices]

    print("key:", obs_conf)
    gen_labels = torch.cat(all_samples, dim=0)
    group_size = gen_labels.shape[0]
    uniques, counts = torch.unique(gen_labels, sorted=True, return_counts=True, dim=0)

    # combination -> (empirical probability, |empirical - reference|),
    # visited from most to least frequent so the printout keeps that order.
    stats = {}
    _, order = torch.sort(counts, dim=0, descending=True)
    for ind in order:
        ky = uniques[ind].view(1, -1)
        cnt = counts[ind]
        fake_prob = cnt / group_size
        print(ky, cnt, "fake prob:", fake_prob)
        comb = tuple(ky.tolist()[0])
        loss = abs(fake_prob - true_dist_dict[comb])
        print("loss", loss)
        stats[comb] = (fake_prob.item(), loss.item())

    # Direct lookup per drawn sample instead of scanning every combination.
    chosen = []
    for ss in samples:
        comb = ss.tolist()[0]
        prob, loss = stats[tuple(comb)]
        print(ss, prob, loss)
        chosen.append((comb, prob, loss))

    print("---")

    _, order = torch.sort(torch.tensor([prob for _, prob, _ in chosen]), descending=True)
    ranked = [chosen[pos] for pos in order.tolist()]

    for comb, prob, loss in ranked:
        print(comb, prob, loss)

    save_res = {
        "obs_comb": [comb for comb, _, _ in ranked],
        "prob": [prob for _, prob, _ in ranked],
        "loss": [loss for _, _, loss in ranked],
    }

    file_name = "prob_save_file_path"
    with open(file_name, 'w') as fp:
        json.dump(save_res, fp)




def get_highest_conditional_sample_for_images(Exp, label_generators):
    """Tabulate generated label combinations for every observed (X1, X2) pair.

    Observational labels are generated and grouped by their first two entries.
    Within each group every distinct combination is listed from most to least
    frequent together with its empirical probability and its absolute
    difference from the reference conditional distribution returned by
    ``get_cond_synthetic_dist``.  The table is printed and written as JSON to
    ``"prob_save_file_path"``.

    Args:
        Exp: experiment object; ``Synthetic_Sample_Size`` and ``label_names``
            are read from it and it is forwarded to the helpers.
        label_generators: generators forwarded to ``get_generated_labels``.

    Returns:
        None.
    """
    label_order = ["X1", "X2", "W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"]

    no_intervention = {}  # observational setting
    bayes_net, _ = get_bayesian_network(Exp, no_intervention, load_scm=1)
    _, _, _, reference = get_cond_synthetic_dist(
        ["W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"],
        ["X1", "X2"],
        Exp.label_names,
        bayes_net["feature"],
    )

    raw_labels = get_generated_labels(Exp, label_generators, {}, {}, no_intervention, label_order,
                                      Exp.Synthetic_Sample_Size)
    discrete_labels = map_dictfill_to_discrete(Exp, raw_labels, label_order)

    # Bucket the rows by their (X1, X2) prefix; each row is stored as a 1 x D tensor.
    groups = {}
    for sample in discrete_labels:
        prefix = (sample[0], sample[1])
        groups.setdefault(prefix, []).append(torch.tensor(sample).view(1, -1))

    result = {"obs_comb": [], "prob": [], "loss": []}

    for prefix, rows in sorted(groups.items()):
        print("key:", prefix)
        stacked = torch.cat(rows, dim=0)
        group_size = stacked.shape[0]
        distinct, _, freq = torch.unique(stacked, sorted=True, return_inverse=True,
                                         return_counts=True, dim=0)

        # Walk the distinct combinations from most to least frequent.
        ranking = torch.sort(freq, dim=0, descending=True)[1]
        for pos in ranking:
            combo = distinct[pos].view(1, -1)
            occurrences = freq[pos]
            est_prob = occurrences / group_size
            print(combo, occurrences, "fake prob:", est_prob)

            combo_key = tuple(combo.tolist()[0])
            gap = abs(est_prob - reference[combo_key])
            print("loss", gap)

            result["obs_comb"].append(combo.tolist())
            result["loss"].append(gap.item())
            result["prob"].append(est_prob.item())

    for combo, prob, gap in zip(result["obs_comb"], result["prob"], result["loss"]):
        print(combo, prob, gap)

    out_path = "prob_save_file_path"
    with open(out_path, 'w') as handle:
        handle.write(json.dumps(result))

    return

def get_highest_interventional_sample_for_images(Exp, label_generators):
    """Record the most frequent generated sample for every do(X1, X2) setting.

    For each combination of intervention values the generators are sampled,
    the rows are grouped by their first two entries, and for every group the
    single most frequent row is recorded together with the absolute difference
    between its frequency over *all* generated rows and the reference joint
    distribution returned by ``get_synthetic_dist``.  The accumulated result is
    written as JSON to ``"prob_save_file_path"``.

    Args:
        Exp: experiment object; ``label_dim`` and ``Synthetic_Sample_Size`` are
            read from it and it is forwarded to the helpers.
        label_generators: generators forwarded to ``get_generated_labels``.

    Returns:
        None.
    """
    obs_vars = ["X1", "X2", "W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"]
    intv_vars = ["X1", "X2"]
    perms = generate_permutations([Exp.label_dim[lb]["feature"] for lb in intv_vars])
    key_vals = [dict(zip(intv_vars, comb)) for comb in perms]

    result = {"obs_comb": [], "loss": []}
    for intv_key in key_vals:
        true_bn, _ = get_bayesian_network(Exp, intv_key, load_scm=1)
        _, _, _, true_dist_dict = get_synthetic_dist(Exp, obs_vars, true_bn["feature"])

        generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_key, obs_vars,
                                                     Exp.Synthetic_Sample_Size)
        generated_labels_full = map_dictfill_to_discrete(Exp, generated_labels_dict, obs_vars)
        n_total = generated_labels_full.shape[0]

        # Group the rows by their (X1, X2) prefix.
        sample_dict = {}
        for row in generated_labels_full:
            sample_dict.setdefault((row[0], row[1]), []).append(row)

        for key, rows in sorted(sample_dict.items()):
            print("key:", key)
            # Build the tensor once; wrapping an existing tensor in torch.tensor()
            # again only copies it and triggers a UserWarning.
            gen_labels = torch.tensor(rows)
            uniques, counts = torch.unique(gen_labels, sorted=True, return_counts=True, dim=0)

            # Only the most frequent combination of the group is reported.
            _, indices = torch.sort(counts, dim=0, descending=True)
            top = indices[0]
            ky = uniques[top]
            cnt = counts[top]
            print(ky, cnt, "fake prob:", cnt / gen_labels.shape[0])

            # The loss is measured against the joint distribution, hence the
            # normalisation by the total number of generated rows.
            loss = abs(cnt / n_total - true_dist_dict[tuple(ky.tolist())])
            print("loss", loss)
            result["obs_comb"].append(ky.tolist())
            result["loss"].append(loss.item())

    file_name = "prob_save_file_path"
    with open(file_name, 'w') as fp:
        json.dump(result, fp)


def get_interventional_sample_for_images(Exp, label_generators):
    """Save stats for ten random samples generated under do(X1=1, X2=4).

    Labels are generated under the fixed intervention, the rows whose first
    two entries equal the intervened values are kept, and every distinct
    combination in that group gets an empirical probability (relative to the
    group) and a loss (absolute difference between its frequency over all
    generated rows and the reference joint distribution from
    ``get_synthetic_dist``).  Ten rows are drawn at random with replacement,
    ordered by descending probability, printed, and written as JSON to
    ``"prob_save_file_path"``.

    Args:
        Exp: experiment object; ``Synthetic_Sample_Size`` is read from it and
            it is forwarded to the helpers.
        label_generators: generators forwarded to ``get_generated_labels``.

    Returns:
        None.

    Raises:
        KeyError: if no generated row matches the intervened (X1, X2) values,
            or if a generated combination cannot be looked up in the
            reference distribution.
    """
    obs_vars = ["X1", "X2", "W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"]
    intv_vars = ["X1", "X2"]

    # Intervention under inspection; {"X1": 0, "X2": 8} was used previously.
    intv_key = {"X1": 1, "X2": 4}
    # Derived from intv_key so the two can never drift apart.
    obs_conf = tuple(intv_key[var] for var in intv_vars)
    num_samples = 10

    true_bn, _ = get_bayesian_network(Exp, intv_key, load_scm=1)
    _, _, _, true_dist_dict = get_synthetic_dist(Exp, obs_vars, true_bn["feature"])

    generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_key, obs_vars,
                                                 Exp.Synthetic_Sample_Size)
    generated_labels_full = map_dictfill_to_discrete(Exp, generated_labels_dict, obs_vars)
    n_total = generated_labels_full.shape[0]

    # Only the obs_conf group is ever used, so keep just those rows.
    # assumes each row is array-like with X1, X2 in positions 0 and 1 — TODO confirm
    all_samples = [torch.tensor(row).view(1, -1)
                   for row in generated_labels_full
                   if (row[0], row[1]) == obs_conf]
    if not all_samples:
        raise KeyError(obs_conf)

    randices = torch.randint(0, len(all_samples), (num_samples,)).tolist()
    samples = [all_samples[idx] for idx in randices]

    print("key:", obs_conf)
    # torch.cat already returns a tensor; no torch.tensor() re-wrap is needed.
    gen_labels = torch.cat(all_samples, dim=0)
    group_size = gen_labels.shape[0]
    uniques, counts = torch.unique(gen_labels, sorted=True, return_counts=True, dim=0)

    # combination -> (probability within the group, loss against the joint).
    stats = {}
    _, order = torch.sort(counts, dim=0, descending=True)
    for ind in order:
        ky = uniques[ind].view(1, -1)
        cnt = counts[ind]
        fake_prob = cnt / group_size
        print(ky, cnt, "fake prob:", fake_prob)
        comb = tuple(ky.tolist()[0])
        loss = abs(cnt / n_total - true_dist_dict[comb])
        print("loss", loss)
        stats[comb] = (fake_prob.item(), loss.item())

    # Direct lookup per drawn sample instead of scanning every combination.
    chosen = []
    for ss in samples:
        comb = ss.tolist()[0]
        prob, loss = stats[tuple(comb)]
        print(ss, prob, loss)
        chosen.append((comb, prob, loss))

    print("---")

    _, order = torch.sort(torch.tensor([prob for _, prob, _ in chosen]), descending=True)
    ranked = [chosen[pos] for pos in order.tolist()]

    for comb, prob, loss in ranked:
        print(comb, prob, loss)

    save_res = {
        "obs_comb": [comb for comb, _, _ in ranked],
        "prob": [prob for _, prob, _ in ranked],
        "loss": [loss for _, _, loss in ranked],
    }

    file_name = "prob_save_file_path"
    with open(file_name, 'w') as fp:
        json.dump(save_res, fp)


def get_cf_samples_for_image(Exp, label_generators):
    """Save stats for twenty random counterfactual samples.

    The counterfactual is ``do(X1=1, X2=5)`` given the evidence
    ``X1p=1, X2p=6``.  Posterior samples are obtained with
    ``rejection_sampling_optimized`` for every evidence setting of the locally
    built query, the generators are run under the intervention, and the rows
    whose first two entries equal the intervened values are analysed.  Twenty
    of them are drawn at random with replacement, ordered by descending
    empirical probability, printed, and written as JSON to
    ``"prob_save_file_path"``.

    No reference counterfactual distribution is computed here, so every
    ``loss`` entry is the placeholder value 100.

    Args:
        Exp: experiment object; ``Synthetic_Sample_Size`` and ``label_dim`` are
            read from it and it is forwarded to the helpers.
        label_generators: generators forwarded to the sampling helpers.

    Returns:
        None.

    Raises:
        KeyError: if no generated row matches the intervened (X1, X2) values.
    """
    obs_vars = ["X1", "X2", "W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"]

    evidence = {"X1p": 1, "X2p": 6}
    intv_key = {"X1": 1, "X2": 5}
    obs_conf = (1, 5)  # (X1, X2) values fixed by intv_key
    num_samples = 20
    placeholder_loss = 100  # no true counterfactual distribution is available here

    n_samples = Exp.Synthetic_Sample_Size

    cf_list = [
        {"intv": ["X1", "X2"], "evid": ["X1p", "X2p"], "expr": "P(V|do(X1,X2),X1p, X2p)"}]
    cf_queries = []
    for cf in cf_list:
        perms = generate_permutations([Exp.label_dim[lb]["feature"] for lb in cf["intv"]]).tolist()
        intv_key_val = [dict(zip(cf["intv"], comb)) for comb in perms]

        perms = generate_permutations([Exp.label_dim[lb]["feature"] for lb in cf["evid"]]).tolist()
        ev_key_val = [dict(zip(cf["evid"], comb)) for comb in perms]

        cf_queries.append({"obs": obs_vars, "intervs": intv_key_val, "evidence": ev_key_val, "expr": cf["expr"]})

    # The original code read an undefined name `cfquery` here (NameError);
    # the query that was just built is the one that is meant.
    cfquery = cf_queries[0]
    evidence_list = list(cfquery["evidence"])
    all_posterior_label, all_posterior_latent, all_gumbel_noise = rejection_sampling_optimized(
        Exp, label_generators, n_samples, evidence_list, max_rejections=0, warn=100)

    kev = asKey(evidence)
    posterior_label = all_posterior_label[kev]
    posterior_latent = all_posterior_latent[kev]
    gumbel_noise = all_gumbel_noise[kev]

    cf_all_labels_dict = get_generated_labels(Exp, label_generators, posterior_label, posterior_latent,
                                              intv_key, obs_vars, n_samples, gumbel_noise=gumbel_noise)
    cf_samples = map_dictfill_to_discrete(Exp, cf_all_labels_dict, obs_vars)

    # Only the obs_conf group is ever used, so keep just those rows.
    # assumes each row is array-like with X1, X2 in positions 0 and 1 — TODO confirm
    all_samples = [torch.tensor(row).view(1, -1)
                   for row in cf_samples
                   if (row[0], row[1]) == obs_conf]
    if not all_samples:
        raise KeyError(obs_conf)

    randices = torch.randint(0, len(all_samples), (num_samples,)).tolist()
    samples = [all_samples[idx] for idx in randices]

    print("key:", obs_conf)
    gen_labels = torch.cat(all_samples, dim=0)
    group_size = gen_labels.shape[0]
    uniques, counts = torch.unique(gen_labels, sorted=True, return_counts=True, dim=0)

    # combination -> empirical probability within the group.
    prob_dict = {}
    _, order = torch.sort(counts, dim=0, descending=True)
    for ind in order:
        ky = uniques[ind].view(1, -1)
        cnt = counts[ind]
        fake_prob = cnt / group_size
        print(ky, cnt, "fake prob:", fake_prob)
        prob_dict[tuple(ky.tolist()[0])] = fake_prob.item()

    # Direct lookup per drawn sample instead of scanning every combination.
    chosen = []
    for ss in samples:
        comb = ss.tolist()[0]
        prob = prob_dict[tuple(comb)]
        print(ss, prob, placeholder_loss)
        chosen.append((comb, prob, placeholder_loss))

    print("---")

    _, order = torch.sort(torch.tensor([prob for _, prob, _ in chosen]), descending=True)
    ranked = [chosen[pos] for pos in order.tolist()]

    for comb, prob, loss in ranked:
        print(comb, prob, loss)

    save_res = {
        "obs_comb": [comb for comb, _, _ in ranked],
        "prob": [prob for _, prob, _ in ranked],
        "loss": [loss for _, _, loss in ranked],
    }

    file_name = "prob_save_file_path"
    with open(file_name, 'w') as fp:
        json.dump(save_res, fp)

def get_highest_cf_samples_for_image(Exp, label_generators):
    """Generate counterfactual samples for the first CF query of ``Exp`` and dump ``result``.

    NOTE(review): this function looks unfinished and cannot run as written.
    It reads several names that are neither parameters nor locals
    (``evidence``, ``intv_key``, ``observed_var``, ``samples``, ``feature``)
    and that are not defined anywhere in the visible file, and it relies on the
    module-level ``cur_mechs``, ``tvd_diff`` and ``kl_diff``.  The hard-coded
    ``evidence_list`` / ``intv_dict`` pair suggests a loop over evidence and
    intervention settings was intended — confirm with the author before use.

    Args:
        Exp: experiment object; ``cf_queries``, ``twin_map`` and
            ``Synthetic_Sample_Size`` are read from it.
        label_generators: generators forwarded to the sampling helpers.

    Returns:
        ``(tvd_diff, kl_diff)`` from module scope on the early exit, otherwise
        None after writing ``result`` as JSON to ``"prob_save_file_path"``.
    """

    # Never filled in below, so the file written at the end only holds empty lists.
    result = {"obs_comb": [], "prob":[], "loss": []}


    # Unused in this function.
    obs_vars=["X1", "X2","W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"]

    # Unused in this function.
    feat = "feature"
    cfquery = Exp.cf_queries[0]

    # NOTE(review): `cur_mechs` is not a parameter. At module level in this file it
    # is a list of lists, so `set(cur_mechs)` would raise TypeError (unhashable
    # list). `tvd_diff` / `kl_diff` are likewise module-level names.
    if bool(set(cfquery["obs"]) & set(cur_mechs)) == False:
        return tvd_diff, kl_diff

    evidence_vars = [Exp.twin_map[lb] for lb in cfquery["evidence"][0].keys()]
    compare_Var = list(evidence_vars)  # getting the intervened variables
    query_str = getdoKey(compare_Var, dict({}))  # getting the scm saving file name
    # `obs_dist` is computed but not used afterwards.
    obs_dist = get_intv_dist(Exp, compare_Var, dict({}), query_str)  # getting the obs distribution of intv variables

    # Both accumulators are unused in this function.
    final_tvd = 0
    final_kl = 0

    n_samples = Exp.Synthetic_Sample_Size

    # Posterior samples are drawn for every evidence setting of the query.
    evidence_list = [evidence for evidence in cfquery["evidence"]]
    all_posterior_label, all_posterior_latent, all_gumbel_noise = rejection_sampling_optimized(Exp, label_generators,
                                                                                               n_samples, evidence_list,
                                                                                               max_rejections=0,
                                                                                               warn=100)


    # NOTE(review): `evidence_list` is overwritten here after it has already been
    # used above, and neither it nor `intv_dict` is read again.
    evidence_list=[{"X1p":0, "X2p":0}, {"X1p":0, "X2p":1}, {"X1p":0, "X2p":2},
                   {"X1p":1, "X2p":6}, {"X1p":1, "X2p":7}, {"X1p":1, "X2p":8}]

    intv_dict=[{0:[3,5]}, {1:[3,5]}, {2:[4,7]},
               {3:[1,3]}, {4:[1,3]}, {5:[6,7]}]



    # NOTE(review): `evidence` is undefined at this point (the comprehension
    # variable above does not leak), so this line raises NameError.
    kev = asKey(evidence)
    posterior_label, posterior_latent, gumbel_noise = all_posterior_label[kev], all_posterior_latent[kev], all_gumbel_noise[kev]

    # NOTE(review): `intv_key` is undefined in this function.
    cf_all_labels_dict = get_generated_labels(Exp, label_generators, posterior_label, posterior_latent,
                                              intv_key, cfquery["obs"], n_samples, gumbel_noise=gumbel_noise)
    cf_samples = map_dictfill_to_discrete(Exp, cf_all_labels_dict, cfquery["obs"])

    # NOTE(review): `observed_var`, `samples` and `feature` are undefined here;
    # presumably cfquery["obs"], cf_samples and feat were meant — confirm.
    upd_dist = get_joint_distributions_from_samples(Exp, observed_var, samples, feature)

    # true_cf_dist = get_cf_dist(Exp, cfquery["obs"], intv_key, evidence, cfquery["expr"], load_dist=True)

    print(f"CF query done for evidence:{evidence}, intv_key: {intv_key} ")


    file_name ="prob_save_file_path"
    with open(file_name, 'w') as fp:
        fp.write(json.dumps(result))

    return


Exp = Experiment("Exp1", set_trainGraph,
                 dist_thresh=0.15,
                 causal_hierarchy=2,
                 Temperature=1,
                 temp_min=0.1,
                 G_hid_dims=[256, 256],
                 D_hid_dims=[256, 256],
                 IMAGE_FILTERS=[128, 64, 32],
                 CRITIC_ITERATIONS=5,
                 LAMBDA_GP=1,
                 learning_rate=2 * 1e-4,
                 Synthetic_Sample_Size=10000,
                 intv_Sample_Size=10000,
                 batch_size=200,
                 features=["feature"],
                 noise_states=100,
                 latent_state=16,
                 Data_intervs=[{}],
                 num_epochs=200,
                 new_experiment=False,
                 obs_state=3
                 )


# get_expected_true_cf(Exp)
Exp.Synthetic_Sample_Size = 10000
Exp.intv_batch_size = Exp.batch_size




last_exp = "saved_result_file_path"
print(last_exp)
Exp.LOAD_MODEL_PATH = last_exp

load_which_models = {"X0": True, "X1": True, "X2": True,
                         "W0": True, "W1": True,
                         "Y0": True, "Y1": True}

# load_which_models = {"X1": False, "X2": False, "W": False, "Ydigit1": False, "Ydigit2": False, "Ycolor": False,
#                                                   "Ythick": False}

# cur_mechs = ["Ydigit1", "Ydigit2", "Ythick"]
# cur_mechs = ["X0", "X1", "X2", "W0","W1", "Y0", "Y1"] #fix this
cur_mechs = [["X0","W0","W1","Y0"], ["X1", "X2","W1","Y1"], ["X0", "X2", "W1", "Y0"], ["X0", "X1", "X2", "W0", "W1", "Y0", "Y1"]] #fix this

label_generators, optimizersMech = get_generators(Exp, load_which_models)


for gen in label_generators:
    label_generators[gen].eval()


with torch.no_grad():
    tvd_diff={}
    kl_diff={}
    for mechs in cur_mechs:
        tvd_diff, kl_diff, true_dist, fake_dist= get_observational_loss(Exp, mechs, label_generators, tvd_diff, kl_diff)
        print(tvd_diff)





# cfquery = Exp.cf_queries[0]
#
# for evidence in cfquery["evidence"]:
#     for intv_key in cfquery["intervs"]:
#         true_cf_dist = get_cf_dist(Exp, cfquery["obs"], intv_key, evidence, cfquery["expr"], load_dist=True)
#         print(f' intv:{intv_key}, evidence:{evidence}, dist:{true_cf_dist}')
#














